# ------------------------------------------------------------------------------
# Calculates univariate correlations between features and the stent/test
# outcomes, and flags which features each lasso model selected
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q medium -R "rusage[mem=1000]" bash 07_calc_feature_outcome_cor.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)

# Working directory for this step: the fitted lasso models and the performance
# table are read from here, and the output CSV is written here.
temp <- here(
  "code", "06_physician_boundedness", "01_behavioral_lasso", "temp"
)

# Helpers ----------------------------------------------------------------------
# Map lasso model filenames to the name of the score they were trained on.
#
# x: character vector of filenames.
# Returns a character vector the same length as `x`:
#   "stent_or_cabg_010_day" when the name contains "stent",
#   "test_010_day" when it (otherwise) contains "test",
#   NA when neither substring is present.
score_from_filename <- function(x) {
  score <- rep(NA_character_, length(x))
  # Assign "test" first so that a "stent" match takes precedence when a name
  # happens to contain both substrings.
  score[grepl("test", x)] <- "test_010_day"
  score[grepl("stent", x)] <- "stent_or_cabg_010_day"
  score
}

# Names of the features with a non-zero coefficient in one fitted model.
#
# which_score:  name of the element of `models` to inspect.
# which_lambda: penalty value forwarded to coef() as `s`.
# models:       named list of fitted models.
# Returns a character vector of coefficient names, without "(Intercept)".
extract_coef_names <- function(which_score, which_lambda, models) {
  fit <- models[[which_score]]
  betas <- coef(fit, s = which_lambda)
  nonzero_names <- rownames(betas)[which(betas != 0)]
  setdiff(nonzero_names, "(Intercept)")
}

# Label every coefficient of one fitted model as selected or not.
#
# which_score:  name of the element of `models` to inspect.
# which_lambda: penalty value forwarded to coef() as `s`.
# models:       named list of fitted models.
# Returns a tibble with one row per row of the coefficient matrix (so any
# intercept row is included): `feature` holds the row name and `selected` is
# "excluded" when the coefficient is exactly zero, "included" otherwise.
identify_selected <- function(which_score, which_lambda, models) {
  fit <- models[[which_score]]
  betas <- coef(fit, s = which_lambda)
  status <- ifelse(betas == 0, "excluded", "included")
  tibble(feature = rownames(betas), selected = status)
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
# `overnight_lab` and `split` are not referenced directly below; presumably
# they fill `{...}` placeholders in the path templates expanded by glue() --
# confirm against lib/filepaths.yml before removing or renaming them.
overnight_lab <- ""
paths <- read_yaml(here::here("lib", "filepaths.yml"))
split <- "random"
x <- readRDS(glue(paths$features$test))
ids <- readRDS(glue(paths$analysis$test_cohort))

performance_raw <- readRDS(file.path(temp, "performance_obs.rds"))

# Fitted lasso objects: files whose name contains "lasso_" and ends in ".rds",
# skipping any "scores" files. The dot is escaped and the extension anchored so
# that only real .rds files match (an unescaped "." matches any character,
# anywhere in the name).
lasso_files <- list.files(temp) %>%
  str_subset("lasso_") %>%
  str_subset("\\.rds$")
# Named list of models, keyed by the score each file was trained on
models <- lasso_files[!grepl("scores", lasso_files)] %>%
  setNames(score_from_filename(.)) %>%
  map(~ readRDS(file.path(temp, .x)))

# ------------------------------------------------------------------------------
message("Exclusions...")
# Row positions of cohort members not flagged for exclusion. The same
# positions are applied to the feature matrix, so `x` and `ids` are presumably
# row-aligned -- the code below relies on it.
keep_obs <- which(!ids$exclude)
ids <- ids[keep_obs, ]
x <- x[keep_obs, ]

# Only the lasso rows of the performance table are used downstream
performance <- performance_raw %>%
  filter(model_name == "lasso")

# Calculate Correlation with Outcomes ------------------------------------------
message("Calculating univariate correlations with outcomes...")
# Row positions of tested patients. Using explicit positions (rather than the
# raw indicator column) means a 0/1 numeric indicator selects the same rows as
# a TRUE/FALSE one, and rows with a missing indicator are left out instead of
# producing NA rows.
tested <- which(as.logical(ids$test_010_day))

# One row per feature column of `x`:
#   cor_test          - correlation with the test outcome, over all rows
#   cor_stent_or_cabg - correlation with the stent/CABG outcome, over tested
#                       rows only
cor_tb <- tibble(
  feature = colnames(x),
  cor_test = apply(
    x, 2, cor,
    ids$test_010_day
  ),
  cor_stent_or_cabg = apply(
    # drop = FALSE keeps a 2-d object even when a single row is selected
    x[tested, , drop = FALSE], 2, cor,
    ids$stent_or_cabg_010_day[tested]
  )
)

# Identify Selected Variables --------------------------------------------------
message("Identifying lasso-selected variables")
# Every distinct (score, lambda) pair that appears in the lasso performance
# table
selection_design <- unique(select(performance, score_name, lambda))

# For each pair, list every model coefficient with its included/excluded flag,
# then expand to one row per (score, lambda, feature)
selection_tb <- selection_design %>%
  mutate(
    tb = map2(
      score_name, lambda,
      ~ identify_selected(.x, .y, models = models)
    )
  ) %>%
  unnest(tb)

# Corr Table -------------------------------------------------------------------
message("Getting correlation coef table...")
# score_name is the var used to train the lasso
# target_name is the outcome var against which the score is evaluated
# Only mismatched score/target rows are kept, so
# score == "stent_or_cabg_010_day" --> evaluating stent lasso against test
# score == "test_010_day" --> evaluating test lasso against stent
#
# NOTE: rows are filtered to measure_name == "r2", so the output column named
# `auc` actually holds the r2 value. The name is left unchanged here because
# readers of the CSV may depend on it.
cross_target_r2 <- filter(
  performance,
  measure_name == "r2",
  score_name != target_name
)

coef_cor_tb <- cross_target_r2 %>%
  inner_join(selection_tb, by = c("score_name", "lambda")) %>%
  inner_join(cor_tb, by = "feature") %>%
  transmute(
    score_name,
    n_coef,
    auc = performance,
    feature,
    selected,
    cor_stent_or_cabg,
    cor_test
  )

# ------------------------------------------------------------------------------
# Save the per-feature table next to the model files
out_path <- file.path(temp, "feature_outcome_cor_by_selection_r2.csv")
write_csv(coef_cor_tb, out_path)

# ------------------------------------------------------------------------------
message("Done.")
